from collections.abc import Sequence
from typing import Literal, cast
from uuid import uuid4

from django.utils import timezone

from langchain_core.messages import (
    AIMessage as LangchainAIMessage,
    BaseMessage,
    HumanMessage as LangchainHumanMessage,
    ToolMessage as LangchainToolMessage,
)
from langchain_core.output_parsers import PydanticToolsParser, StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langgraph.errors import NodeInterrupt
from pydantic import BaseModel, Field, ValidationError

from posthog.schema import (
    AssistantForm,
    AssistantFormOption,
    AssistantMessage,
    AssistantMessageMetadata,
    CachedEventTaxonomyQueryResponse,
    CacheMissResponse,
    ContextMessage,
    EventTaxonomyItem,
    EventTaxonomyQuery,
    HumanMessage,
    QueryStatusResponse,
)

from posthog.event_usage import report_user_action
from posthog.hogql_queries.ai.event_taxonomy_query_runner import EventTaxonomyQueryRunner
from posthog.hogql_queries.query_runner import ExecutionMode
from posthog.sync import database_sync_to_async
from posthog.utils import human_list

from ee.hogai.artifacts.utils import unwrap_visualization_artifact_content
from ee.hogai.core.agent_modes import SlashCommandName
from ee.hogai.core.mixins import AssistantContextMixin
from ee.hogai.core.node import AssistantNode
from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.utils.helpers import filter_and_merge_messages, find_last_message_of_type
from ee.hogai.utils.markdown import remove_markdown
from ee.hogai.utils.prompt import format_prompt_string
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import ArtifactRefMessage
from ee.models.assistant import CoreMemory

from .parsers import MemoryCollectionCompleted, compressed_memory_parser, raise_memory_updated
from .prompts import (
    ENQUIRY_INITIAL_MESSAGE,
    INITIALIZE_CORE_MEMORY_SYSTEM_PROMPT,
    INITIALIZE_CORE_MEMORY_WITH_BUNDLE_IDS_USER_PROMPT,
    INITIALIZE_CORE_MEMORY_WITH_DOMAINS_USER_PROMPT,
    MEMORY_COLLECTOR_PROMPT,
    MEMORY_COLLECTOR_WITH_VISUALIZATION_PROMPT,
    MEMORY_INITIALIZED_CONTEXT_PROMPT,
    MEMORY_ONBOARDING_ENQUIRY_PROMPT,
    ONBOARDING_COMPRESSION_PROMPT,
    SCRAPING_CONFIRMATION_MESSAGE,
    SCRAPING_INITIAL_MESSAGE,
    SCRAPING_REJECTION_MESSAGE,
    SCRAPING_SUCCESS_KEY_PHRASE,
    SCRAPING_TERMINATION_MESSAGE,
    SCRAPING_VERIFICATION_MESSAGE,
    TOOL_CALL_ERROR_PROMPT,
)


class MemoryInitializerContextMixin(AssistantContextMixin):
    async def _aretrieve_context(self, *, config: RunnableConfig | None = None) -> EventTaxonomyItem | None:
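        """Retrieve the most relevant context property for memory onboarding.

        Prefers web domains ($pageview/$host) and falls back to app bundle IDs
        ($screen/$app_namespace), filtering out localhost-style values. Returns None
        when nothing usable is found. Illustrative (assumed) return value:
        EventTaxonomyItem(property="$host", sample_values=["posthog.com"]).
        """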
        if config and "_mock_memory_onboarding_context" in config.get("configurable", {}):
            # Only for evals/tests (patch() doesn't work there, because evals run concurrently and async)
            return config["configurable"]["_mock_memory_onboarding_context"]
        # Retrieve the origin domain.
        response = await self._arun_taxonomy_query("$pageview", "$host")
        if not isinstance(response, CachedEventTaxonomyQueryResponse):
            raise ValueError("Failed to query the event taxonomy.")
        # Fall back to the app bundle ID if no domains were found.
        if not response.results:
            response = await self._arun_taxonomy_query("$screen", "$app_namespace")
        if not isinstance(response, CachedEventTaxonomyQueryResponse):
            raise ValueError("Failed to query the event taxonomy.")
        if not response.results:
            return None
        item = response.results[0]
        # Exclude localhost from sample values. We could maybe do this at the query level.
        # Note: if there are $host values but they're all localhost, we miss out on any potential
        # $app_namespace values. That's probably okay; it keeps the code simpler here.
        item.sample_values = [
            v
            for v in item.sample_values
            if v != "localhost" and not v.startswith("localhost:") and not v.startswith("127.0.0.1")
        ]
        if not item.sample_values:
            return None
        return item

    async def _arun_taxonomy_query(
        self, event: str, property_name: str
    ) -> CacheMissResponse | QueryStatusResponse | CachedEventTaxonomyQueryResponse:
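        """Run an EventTaxonomyQuery for the given event/property pair.

        The synchronous query runner is wrapped in database_sync_to_async so it
        doesn't block the event loop.
        """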
        def run_query() -> CacheMissResponse | QueryStatusResponse | CachedEventTaxonomyQueryResponse:
            runner = EventTaxonomyQueryRunner(
                team=self._team, query=EventTaxonomyQuery(event=event, properties=[property_name])
            )
            return runner.run(ExecutionMode.RECENT_CACHE_CALCULATE_ASYNC_IF_STALE_AND_BLOCKING_ON_MISS)

        return await database_sync_to_async(run_query, thread_sensitive=False)()


class MemoryOnboardingShouldRunMixin(AssistantNode):
    async def should_run_onboarding_at_start(self, state: AssistantState) -> Literal["continue", "memory_onboarding"]:
        """
        Only trigger memory onboarding when it is explicitly requested with the /init command.
        If another user has already started the onboarding process, or it has already been completed, do not trigger it again.
        If the AssistantState contains no messages, do not run onboarding.
        """
        core_memory = await self._aget_core_memory()

        if core_memory and core_memory.is_scraping_pending:
            # A user has already started onboarding; don't allow other users to start it concurrently until the timeout is reached
            return "continue"

        if not state.messages:
            return "continue"

        last_message = state.messages[-1]
        if isinstance(last_message, HumanMessage):
            if last_message.content.startswith("/"):
                await database_sync_to_async(report_user_action)(
                    self._user, "Max slash command used", {"slash_command": last_message.content}, team=self._team
                )
            if last_message.content == SlashCommandName.FIELD_INIT:
                return "memory_onboarding"
        return "continue"


class MemoryOnboardingNode(MemoryInitializerContextMixin, MemoryOnboardingShouldRunMixin):
    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
        core_memory, _ = await CoreMemory.objects.aget_or_create(team=self._team)
        await core_memory.achange_status_to_pending()

        # If the team has a product description, initialize the memory with it.
        if product_description := await database_sync_to_async(lambda: self._team.project.product_description)():
            await core_memory.aappend_question_to_initial_text("What does the company do?")
            await core_memory.aappend_answer_to_initial_text(product_description)
            return PartialAssistantState(
                messages=[
                    AssistantMessage(
                        content=ENQUIRY_INITIAL_MESSAGE,
                        id=str(uuid4()),
                    )
                ]
            )

        retrieved_prop = await self._aretrieve_context(config=config)

        # No host or app bundle ID found.
        if not retrieved_prop:
            return PartialAssistantState(
                messages=[
                    AssistantMessage(
                        content=ENQUIRY_INITIAL_MESSAGE,
                        id=str(uuid4()),
                    )
                ]
            )

        return PartialAssistantState(
            messages=[
                AssistantMessage(
                    content=SCRAPING_INITIAL_MESSAGE.format(
                        domains_or_bundle_ids_formatted=human_list([f"`{x}`" for x in retrieved_prop.sample_values])
                    ),
                    id=str(uuid4()),
                )
            ]
        )


class MemoryInitializerNode(MemoryInitializerContextMixin, AssistantNode):
    """
    Scrapes the product description from the retrieved origin domains or app bundle IDs.
    """

    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
        core_memory, _ = await CoreMemory.objects.aget_or_create(team=self._team)
        if core_memory.initial_text:
            # Reset the initial text if this isn't the first time /init has been run
            core_memory.initial_text = ""
            await core_memory.asave()
        retrieved_prop = await self._aretrieve_context(config=config)
        # No host or app bundle ID found, continue.
        if not retrieved_prop:
            return None

        prompt = ChatPromptTemplate.from_messages(
            [("system", INITIALIZE_CORE_MEMORY_SYSTEM_PROMPT)], template_format="mustache"
        )
        if retrieved_prop.property == "$host":
            prompt += ChatPromptTemplate.from_messages(
                [("human", INITIALIZE_CORE_MEMORY_WITH_DOMAINS_USER_PROMPT)], template_format="mustache"
            ).partial(domains=",".join(retrieved_prop.sample_values))
        else:
            prompt += ChatPromptTemplate.from_messages(
                [("human", INITIALIZE_CORE_MEMORY_WITH_BUNDLE_IDS_USER_PROMPT)], template_format="mustache"
            ).partial(bundle_ids=", ".join(retrieved_prop.sample_values))

        chain = prompt | self._model() | StrOutputParser()
        answer = await chain.ainvoke({}, config=config)
        # The model failed to scrape the data; surface the termination message.
        if answer == SCRAPING_TERMINATION_MESSAGE:
            return PartialAssistantState(messages=[AssistantMessage(content=answer, id=str(uuid4()))])
        # Otherwise, proceed to confirmation that the memory is correct.
        await core_memory.aappend_question_to_initial_text("What does the company do?")
        await core_memory.aappend_answer_to_initial_text(answer)
        return PartialAssistantState(messages=[AssistantMessage(content=answer, id=str(uuid4()))])

    def router(self, state: AssistantState) -> Literal["interrupt", "continue"]:
        last_message = state.messages[-1]
        if isinstance(last_message, AssistantMessage) and SCRAPING_SUCCESS_KEY_PHRASE in last_message.content:
            return "interrupt"
        return "continue"

    def _model(self):
        return MaxChatOpenAI(
            model="gpt-5-mini",
            streaming=True,
            use_responses_api=True,
            store=False,  # We can't store, because we want zero data retention
            reasoning={
                "summary": "auto",  # Without this, there's no reasoning summaries! Only works with reasoning models
            },
            user=self._user,
            team=self._team,
        ).bind_tools([{"type": "web_search"}])


class MemoryInitializerInterruptNode(AssistantNode):
    """
    Prompts the user to confirm or reject the scraped memory.
    """

    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
        raise NodeInterrupt(
            AssistantMessage(
                content=SCRAPING_VERIFICATION_MESSAGE,
                meta=AssistantMessageMetadata(
                    form=AssistantForm(
                        options=[
                            AssistantFormOption(value=SCRAPING_CONFIRMATION_MESSAGE, variant="primary"),
                            AssistantFormOption(value=SCRAPING_REJECTION_MESSAGE),
                        ]
                    )
                ),
                id=str(uuid4()),
            )
        )


class MemoryOnboardingEnquiryNode(AssistantNode):
    """
    Prompts the user for more information about their product, features, business, etc.
    """

    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
        human_message = find_last_message_of_type(state.messages, HumanMessage)
        if not human_message:
            raise ValueError("No human message found.")

        core_memory, _ = await CoreMemory.objects.aget_or_create(team=self._team)
        if (
            human_message.content not in [SCRAPING_CONFIRMATION_MESSAGE, SCRAPING_REJECTION_MESSAGE]
            and not human_message.content.startswith("/")  # Ignore slash commands
            and core_memory.initial_text.endswith("Answer:")
        ):
            # The user is answering a question
            await core_memory.aappend_answer_to_initial_text(human_message.content)

        if human_message.content == SCRAPING_REJECTION_MESSAGE:
            core_memory.initial_text = ""
            await core_memory.asave()

        answers_left = core_memory.answers_left
        if answers_left > 0:
            prompt = ChatPromptTemplate.from_messages(
                [("system", MEMORY_ONBOARDING_ENQUIRY_PROMPT)], template_format="mustache"
            ).partial(core_memory=core_memory.initial_text, questions_left=answers_left)

            chain = prompt | self._model | StrOutputParser()
            response = await chain.ainvoke({}, config=config)

            if "[Done]" not in response and "===" in response:
                question = self._format_question(response)
                await core_memory.aappend_question_to_initial_text(question)
                return PartialAssistantState(onboarding_question=question)
        return PartialAssistantState(onboarding_question=None, answers_left=None)

    @property
    def _model(self):
        return MaxChatOpenAI(
            model="gpt-4.1",
            temperature=0.3,
            disable_streaming=True,
            stop_sequences=["[Done]"],
            user=self._user,
            team=self._team,
        )

    async def arouter(self, state: AssistantState) -> Literal["continue", "interrupt"]:
        core_memory, _ = await CoreMemory.objects.aget_or_create(team=self._team)
        if state.onboarding_question and core_memory.answers_left > 0:
            return "interrupt"
        return "continue"

    def _format_question(self, question: str) -> str:
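        """Extract the question from the model output and strip its markdown.

        The enquiry prompt is expected to emit reasoning, then "===", then the
        question, e.g. (illustrative): "reasoning...===What is your pricing model?"
        becomes "What is your pricing model?".
        """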
        if "===" in question:
            question = question.split("===")[1]
        return remove_markdown(question)


class MemoryOnboardingEnquiryInterruptNode(AssistantNode):
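    """
    Interrupts the graph with the pending onboarding question, unless that question
    was already surfaced as the last assistant message.
    """
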
    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
        last_assistant_message = find_last_message_of_type(state.messages, AssistantMessage)
        if not state.onboarding_question:
            raise ValueError("No onboarding question found.")
        if last_assistant_message and last_assistant_message.content != state.onboarding_question:
            raise NodeInterrupt(AssistantMessage(content=state.onboarding_question, id=str(uuid4())))
        return PartialAssistantState(onboarding_question=None)


class MemoryOnboardingFinalizeNode(AssistantNode):
    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
        core_memory, _ = await CoreMemory.objects.aget_or_create(team=self._team)
        # Compress the question/answer memory before saving it
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", ONBOARDING_COMPRESSION_PROMPT),
                ("human", "{memory_content}"),
            ]
        )
        chain = prompt | self._model | StrOutputParser() | compressed_memory_parser
        compressed_memory = cast(str, await chain.ainvoke({"memory_content": core_memory.initial_text}, config=config))
        compressed_memory = compressed_memory.replace("\n", " ").strip()
        await core_memory.aset_core_memory(compressed_memory)

        context_message = ContextMessage(
            content=format_prompt_string(MEMORY_INITIALIZED_CONTEXT_PROMPT, core_memory=core_memory.initial_text),
            id=str(uuid4()),
        )
        return PartialAssistantState(
            messages=[context_message],
            start_id=context_message.id,
            root_conversation_start_id=context_message.id,
        )

    @property
    def _model(self):
        return MaxChatOpenAI(
            model="gpt-4.1-mini",
            temperature=0.3,
            disable_streaming=True,
            stop_sequences=["[Done]"],
            user=self._user,
            team=self._team,
        )


# The snake_case class name matters here: it becomes the tool name exposed to the model. Do not change it.
class core_memory_append(BaseModel):
    """
    Appends a new memory fragment to persistent storage.
    """

    memory_content: str = Field(description="The content of a new memory to be added to storage.")


# The snake_case class name matters here: it becomes the tool name exposed to the model. Do not change it.
class core_memory_replace(BaseModel):
    """
    Replaces a specific fragment of memory (word, sentence, paragraph, etc.) with another in persistent storage.
    """

    original_fragment: str = Field(description="The content of the memory to be replaced.")
    new_fragment: str = Field(description="The content to replace the existing memory with.")


memory_collector_tools = [core_memory_append, core_memory_replace]
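# Note: when bound via bind_tools, the class names above become the tool names the
# model emits. An illustrative (assumed) tool call, as parsed by LangChain:
#   {"id": "call_abc", "name": "core_memory_append", "args": {"memory_content": "..."}}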


class MemoryCollectorNode(MemoryOnboardingShouldRunMixin):
    """
    The Memory Collector manages the core memory of the agent. Core memory is a text containing facts about a user's company and product. It helps the agent save and remember facts that could be useful for insight generation or other agentic functions requiring deeper context about the product.
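
    Flow (sketch): arun asks the model for core_memory_append/core_memory_replace tool
    calls against the current core memory, router sends any pending tool calls to
    MemoryCollectorToolsNode to be applied, and the loop ends once the model signals
    that memory collection is complete.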
    """

    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
        if await self.should_run_onboarding_at_start(state) != "continue":
            return None

        # If an interrupt left an unhandled tool call, go to the tools node first.
        node_messages = state.memory_collection_messages or []
        if not self._check_tool_messages_are_valid(node_messages):
            return None

        # Check if the last message is a /remember command
        remember_command_result = self._handle_remember_command(state)
        if remember_command_result:
            return PartialAssistantState(memory_collection_messages=[remember_command_result])

        prompt = ChatPromptTemplate.from_messages(
            [("system", MEMORY_COLLECTOR_PROMPT)], template_format="mustache"
        ) + await self._aconstruct_messages(state)
        chain = prompt | self._model | raise_memory_updated

        try:
            response = await chain.ainvoke(
                {
                    "core_memory": await self._aget_core_memory_text(),
                    "date": timezone.now().strftime("%Y-%m-%d"),
                },
                config=config,
            )
        except MemoryCollectionCompleted:
            return PartialAssistantState(memory_collection_messages=None)
        return PartialAssistantState(memory_collection_messages=[*node_messages, cast(LangchainAIMessage, response)])

    def router(self, state: AssistantState) -> Literal["tools", "next"]:
        if not state.memory_collection_messages:
            return "next"
        return "tools"

    @property
    def _model(self):
        return MaxChatOpenAI(
            model="gpt-4.1", temperature=0.3, disable_streaming=True, user=self._user, team=self._team, billable=True
        ).bind_tools(memory_collector_tools)

    async def _aconstruct_messages(self, state: AssistantState) -> list[BaseMessage]:
        node_messages = state.memory_collection_messages or []

        filtered_messages = filter_and_merge_messages(
            state.messages, entity_filter=(HumanMessage, AssistantMessage, ArtifactRefMessage)
        )
        enriched_messages = await self.context_manager.artifacts.aenrich_messages(filtered_messages)

        conversation: list[BaseMessage] = []
        for message in enriched_messages:
            if isinstance(message, HumanMessage):
                conversation.append(LangchainHumanMessage(content=message.content, id=message.id))
            elif isinstance(message, AssistantMessage):
                conversation.append(LangchainAIMessage(content=message.content, id=message.id))
            elif content := unwrap_visualization_artifact_content(message):
                schema = content.query.model_dump_json(exclude_unset=True, exclude_none=True)
                conversation += ChatPromptTemplate.from_messages(
                    [
                        ("assistant", MEMORY_COLLECTOR_WITH_VISUALIZATION_PROMPT),
                    ],
                    template_format="mustache",
                ).format_messages(
                    schema=schema,
                )

        # Keep only the last 10 conversation messages, then append this node's own messages.
        messages = [*conversation[-10:], *node_messages]
        return messages

    def _handle_remember_command(self, state: AssistantState) -> LangchainAIMessage | None:
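        """Turn a "/remember <content>" message into a direct core_memory_append tool call.

        Illustrative example: "/remember We charge per seat." yields an AIMessage whose
        single tool call appends "We charge per seat." to core memory.
        """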
        last_message = state.messages[-1] if state.messages else None
        if (
            not isinstance(last_message, HumanMessage)
            or last_message.content.split(" ", 1)[0] != SlashCommandName.FIELD_REMEMBER
        ):
            # Not a /remember command, skip!
            return None

        # Extract the content to remember (everything after "/remember ")
        remember_content = last_message.content[len(SlashCommandName.FIELD_REMEMBER) :].strip()
        if remember_content:
            # Create a direct memory append tool call
            return LangchainAIMessage(
                content="I'll remember that for you.",
                tool_calls=[
                    {"id": str(uuid4()), "name": "core_memory_append", "args": {"memory_content": remember_content}}
                ],
                id=str(uuid4()),
            )
        else:
            return LangchainAIMessage(content="There's nothing to remember!", id=str(uuid4()))

    def _check_tool_messages_are_valid(self, messages: Sequence[BaseMessage]) -> bool:
        """Validates that all AIMessages have associated ToolCall messages."""
        mapping = {message.tool_call_id: message for message in messages if isinstance(message, LangchainToolMessage)}
        tool_ids: set[str] = set()
        for message in messages:
            if isinstance(message, LangchainAIMessage):
                tool_ids.update(tool["id"] for tool in message.tool_calls if tool["id"] is not None)
        return set(mapping.keys()) == tool_ids


class MemoryCollectorToolsNode(AssistantNode):
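    """
    Applies the memory tool calls emitted by the Memory Collector: appends or replaces
    core memory fragments and reports each outcome back as a tool message.
    """
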
    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
        node_messages = state.memory_collection_messages
        if not node_messages:
            raise ValueError("No memory collection messages found.")
        last_message = node_messages[-1]
        if not isinstance(last_message, LangchainAIMessage):
            raise ValueError("Last message must be an AI message.")
        core_memory, _ = await CoreMemory.objects.aget_or_create(team=self._team)

        tools_parser = PydanticToolsParser(tools=memory_collector_tools)
        try:
            tool_calls: list[core_memory_append | core_memory_replace] = await tools_parser.ainvoke(
                last_message, config=config
            )
        except ValidationError as e:
            failover_messages = ChatPromptTemplate.from_messages(
                [("user", TOOL_CALL_ERROR_PROMPT)], template_format="mustache"
            ).format_messages(validation_error_message=e.errors(include_url=False))
            return PartialAssistantState(
                memory_collection_messages=[*node_messages, *failover_messages],
            )

        new_messages: list[LangchainToolMessage] = []
        for tool_call, schema in zip(last_message.tool_calls, tool_calls):
            if isinstance(schema, core_memory_append):
                await core_memory.aappend_core_memory(schema.memory_content)
                new_messages.append(LangchainToolMessage(content="Memory appended.", tool_call_id=tool_call["id"]))
            elif isinstance(schema, core_memory_replace):
                try:
                    await core_memory.areplace_core_memory(schema.original_fragment, schema.new_fragment)
                    new_messages.append(LangchainToolMessage(content="Memory replaced.", tool_call_id=tool_call["id"]))
                except ValueError as e:
                    new_messages.append(LangchainToolMessage(content=str(e), tool_call_id=tool_call["id"]))

        return PartialAssistantState(
            memory_collection_messages=[*node_messages, *new_messages],
        )
